# -*- coding: utf-8 -*-
"""
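Scenario simulation for spatial–institutional misfit (baseline vs S1–S4)

Changes relative to the previous version:
- Smaller S3 improvement, larger S4 improvement. ``calibrate_deltas`` defaults to
  target_dp S1=0.05, S2=0.05, S3=0.07, S4=0.14 with an extra S3 damping of 0.85;
  the main run overrides these with S4=0.16 and S3_DAMP=0.75.
- Layout: the Baseline panel spans both rows of the left column (same height as
  the two stacked panels on the right).
- Unchanged: automatic Δ calibration, adjustments proportional to each node's
  excess/gap, adaptive colour scale and bubble sizes, Y_t as the SLAG fallback.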
"""

from pathlib import Path
from typing import Optional, Tuple, Dict, List, Iterable
import re
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# ================== Paths ==================
# Project root on the author's Windows machine; every input/output location
# below is this root followed by a fixed relative suffix.
_PROJECT_ROOT = r"D:\Fragmentation of MPA governance\mpa_network_misfit\pythonProject1\.venv"

PANEL_CSV = Path(_PROJECT_ROOT + r"\outputs_alaam\alaam\alaam_nodes_panel__robust.csv")
COEF_CSV = Path(_PROJECT_ROOT + r"\outputs_alaam\alaam\alaam_coefficients_pooled.csv")
COORD_CSV = Path(_PROJECT_ROOT + r"\outputs_alaam\nodes\misfit_nodes_all_years_unified.csv")

OUT_DIR = Path(_PROJECT_ROOT + r"\outputs_alaam\alaam\scenarios")
FIG_DIR = OUT_DIR / "figs"

# Create the output tree up front so later to_csv / savefig calls cannot fail
# on a missing directory.
OUT_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR.mkdir(parents=True, exist_ok=True)

# ======== Column detection / fallback switches ========
# Always use the z-score of the outcome column (Y_t) as the spatial-lag (SLAG)
# regressor, even when a dedicated spatial-lag column is present.
FORCE_SLAG_FROM_Y = True
# With FORCE_SLAG_FROM_Y off: raise instead of falling back when no
# spatial-lag column (z-scored or raw) can be found.
STRICT_FAIL_SLAG = False
# Print the panel's raw column names and the final column mapping.
DEBUG_PRINT_COLUMNS = True

# ================== Canonical column names ==================
# Every detected / derived column is renamed to one of these names.
COL_SLAG = "spatial_lag_misfit_t1_z"
COL_CORR = "corridor_betweenness_t1_z"
COL_LOOI = "z_prov_count_looi_t1_z"
COL_MISF = "high_misfit_t1"
COL_SEA = "sea_region"
COL_YEAR = "year"
COL_LON = "lon"
COL_LAT = "lat"
Y_COL = "Y_t"

# Alias lists, tried in order when locating each logical column in the panel.
SEA_ALIASES = [
    "sea_region", "SeaRegion", "sea", "sea_region_std", "Sea_Region", "SEA_REGION",
]
YEAR_ALIASES = ["year", "Year", "YEAR"]
MISFIT_ALIASES = [
    "high_misfit_t1", "misfit_high_t1", "high_misfit", "is_misfit",
    "misfit_flag", "misfit_t1", "misfit",
]
Y_ALIASES = ["Y_t", "Y", "y", "Y_2015", "Outcome", "outcome", "label", "target"]

# Spatial lag: already z-scored names first, then raw names to be z-scored.
SLAG_ALIASES = [
    "spatial_lag_misfit_t1_z", "spatial_lag_misfit_z",
    "slag_misfit_t1_z", "slag_misfit_z",
    "spw_lag_misfit_t1_z", "spaw_lag_misfit_t1_z", "spaW_lag_misfit_t1_z",
    "spatial_lag_Y_t_z", "slag_Y_t_z", "spaw_Y_t_z", "spw_Y_t_z",
]
SLAG_RAW_CAND = [
    "spatial_lag_misfit_t1", "spatial_lag_misfit",
    "slag_misfit_t1", "slag_misfit",
    "spw_lag_misfit_t1", "spaw_lag_misfit_t1", "spaW_lag_misfit_t1",
    "spatial_lag_Y_t", "slag_Y_t", "spaw_Y_t", "spw_Y_t",
]

# Corridor betweenness: z-scored names, raw names, then looser surrogates.
CORR_ALIASES = [
    "corridor_betweenness_t1_z", "corridor_betweenness_z",
    "corridor_btwn_t1_z", "corr_btwn_t1_z",
    "corridor_betw_t1_z", "corridor_btw_t1_z",
]
CORR_RAW_CAND = [
    "corridor_betweenness_t1", "corridor_betweenness",
    "corridor_btwn_t1", "corridor_btwn",
    "corridor_betw_t1", "corridor_btw_t1",
]
CORR_SURROGATES = ["SC_gate", "SC_deg", "gate", "deg"]

# ================== Scenario parameters (fixed fallbacks) ==================
# Sea-region quantile above which S1/S3 treat a node's spatial lag as "excess".
S1_TOP_Q = 0.50
# Fixed shifts, used only when the calibrated intensities are unavailable.
DELTA_S1 = -0.90      # spatial lag, S1/S3 (and S4 fallback)
DELTA_S2 = 0.90       # corridor betweenness, S2/S3 (and S4 fallback)
DELTA_LOOI = 1.20     # LOOI shift applied in S4 when the column exists
# Every modified regressor is kept inside this z-score band.
CLIP_MIN = -1.0
CLIP_MAX = 1.0

# Panel subset that is simulated and plotted.
FILTER_YEAR = 2025
FILTER_TO_MISFIT = True

# Marker area (points^2) of a node with zero improvement.
BASE_SIZE = 30.0

# ================== Figure style ==================
# Global matplotlib defaults: white figure, light panel background, dashed
# light-grey grid, muted axes frame and bold titles.
_RC_STYLE = {
    "figure.facecolor": "white",
    "axes.facecolor": "#f9f9fa",
    "axes.grid": True,
    "grid.color": "#dcdfe3",
    "grid.alpha": 0.6,
    "grid.linestyle": "--",
    "axes.edgecolor": "#c7c9cc",
    "axes.titleweight": "bold",
    "font.size": 11,
}
plt.rcParams.update(_RC_STYLE)

# Fill and edge colour of the nodes in the baseline panel.
BASELINE_NODE_COLOR = "#C4D8E9"
BASELINE_EDGE_COLOR = "white"

# ================== 工具函数 ==================
def invlogit(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``.

    *x* (scalar or array) is first clipped to [-40, 40] so ``exp`` cannot
    overflow for extreme linear predictors.
    """
    bounded = np.clip(x, -40, 40)
    return 1.0 / (np.exp(-bounded) + 1.0)

def _detect_id_col(df: pd.DataFrame) -> Optional[str]:
    cands = ["MPA_ID", "MPA-ID", "Id", "id", "node_id"]
    for c in cands:
        if c in df.columns: return c
    low = {c.lower(): c for c in df.columns}
    for c in ["mpa_id", "mpa-id", "id", "node_id"]:
        if c in low: return low[c]
    return None

def _detect_coord_cols(df: pd.DataFrame) -> Tuple[Optional[str], Optional[str]]:
    cand_lon = ["lon","longitude","x","Lon","LONGITUDE","LON","X","中心经度","经度"]
    cand_lat = ["lat","latitude","y","Lat","LATITUDE","LAT","Y","中心纬度","纬度"]
    lon_col = next((c for c in cand_lon if c in df.columns), None)
    lat_col = next((c for c in cand_lat if c in df.columns), None)
    return lon_col, lat_col

def _zscore(series: pd.Series) -> pd.Series:
    s = pd.to_numeric(series, errors="coerce")
    m, sd = s.mean(), s.std(ddof=0)
    if sd == 0 or pd.isna(sd): return pd.Series(0.0, index=s.index)
    return (s - m) / sd

def _first_match(cols: Iterable[str], *patterns_or_aliases: Iterable[str]) -> Optional[str]:
    colset = list(cols)
    for group in patterns_or_aliases:
        for a in group:
            if a in colset: return a
    for group in patterns_or_aliases:
        for pat in group:
            try:
                if pat.startswith("^") or pat.endswith("$") or ("(?i)" in pat) or ("[" in pat):
                    regex = re.compile(pat, flags=re.IGNORECASE)
                else:
                    regex = re.compile(re.escape(pat), flags=re.IGNORECASE)
            except re.error:
                regex = re.compile(re.escape(pat), flags=re.IGNORECASE)
            for c in colset:
                if regex.search(str(c)): return c
    return None

def parse_coefs(path: Path) -> Dict[str, float]:
    """Read a coefficient table and return a ``{term: coefficient}`` mapping.

    The term and coefficient columns are found case-insensitively among common
    aliases (``term``/``variable``/``name``/... and ``coef``/``estimate``/
    ``beta``/...). ``(Intercept)`` is renamed to ``Intercept``. A coefficient
    cell that is not a plain number has its first numeric substring extracted
    (e.g. ``"0.31***"`` -> 0.31); anything still unparsable becomes 0.0.
    For duplicated terms the last row wins, and an ``Intercept`` of 0.0 is
    prepended when the table has none.

    Raises:
        KeyError: no term or no coefficient column could be identified.
    """
    table = pd.read_csv(path, dtype=str, keep_default_na=False)
    lower_to_actual = {name.lower(): name for name in table.columns}

    def _resolve(options):
        for opt in options:
            if opt in lower_to_actual:
                return lower_to_actual[opt]
        return None

    term_col = _resolve(("term", "variable", "name", "param", "terms"))
    coef_col = _resolve(("coef", "coefficient", "estimate", "est", "b", "beta", "value"))
    if term_col is None or coef_col is None:
        raise KeyError("系数表需包含列：term 与 coef（或其别名）")

    terms = table[term_col].astype(str).str.strip().replace({"(Intercept)": "Intercept"})
    raw_coef = table[coef_col]
    numbers = pd.to_numeric(raw_coef, errors="coerce")
    if numbers.isna().any():
        # Salvage numbers embedded in annotated cells such as "0.31***".
        pulled = raw_coef.astype(str).str.extract(r"([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)", expand=False)
        numbers = numbers.fillna(pd.to_numeric(pulled, errors="coerce"))

    coef_frame = pd.DataFrame({"term": terms, "coef": numbers.fillna(0.0)})
    coef_frame = coef_frame.drop_duplicates("term", keep="last")
    if "Intercept" not in set(coef_frame["term"]):
        intercept_row = pd.DataFrame([{"term": "Intercept", "coef": 0.0}])
        coef_frame = pd.concat([intercept_row, coef_frame], ignore_index=True)
    return dict(zip(coef_frame["term"], coef_frame["coef"]))

# ---------- 坐标读取 ----------
def _try_load_coords_from_mpa_excels(search_dirs: List[Path]) -> Optional[pd.DataFrame]:
    """Best-effort coordinate lookup from ``<dir>/<year>MPA.xlsx`` workbooks.

    Every directory in *search_dirs* is probed for ``2015MPA.xlsx``,
    ``2020MPA.xlsx`` and ``2025MPA.xlsx``. From each workbook that loads and
    has a detectable ID column plus lon/lat columns, those three columns are
    kept. Workbooks that fail to load or lack those columns are skipped with a
    printed warning rather than silently.

    Returns:
        Frame with columns ``MPA_ID_COORD``, ``COL_LON``, ``COL_LAT``. Rows
        without coordinates are dropped and the first record per ID is kept.
        Returns None when no workbook contributed any rows.
    """
    frames = []
    for folder in search_dirs:
        for year in (2015, 2020, 2025):
            xlsx = folder / f"{year}MPA.xlsx"
            if not xlsx.exists():
                continue
            try:
                sheet = pd.read_excel(xlsx)
            except Exception as exc:  # best-effort fallback: any read failure just skips this file
                print(f"[coords] 跳过 {xlsx}：读取失败（{type(exc).__name__}: {exc}）")
                continue
            id_col = _detect_id_col(sheet)
            lonc, latc = _detect_coord_cols(sheet)
            if id_col is None or not (lonc and latc):
                print(f"[coords] 跳过 {xlsx}：未检测到 ID 或经纬度列。")
                continue
            sub = sheet[[id_col, lonc, latc]].copy()
            sub.columns = ["MPA_ID_COORD", COL_LON, COL_LAT]
            frames.append(sub)
    if not frames:
        return None
    allc = pd.concat(frames, ignore_index=True)
    allc = allc.dropna(subset=[COL_LON, COL_LAT])
    # Several years may list the same MPA; keep the first (earliest searched).
    allc = allc.drop_duplicates("MPA_ID_COORD", keep="first")
    return allc[["MPA_ID_COORD", COL_LON, COL_LAT]].copy()

def load_coords(coord_csv: Path, panel_dir: Path) -> pd.DataFrame:
    """Load node coordinates as ``MPA_ID_COORD``, ``COL_LON``, ``COL_LAT`` (+ ``COL_YEAR``).

    The ID and lon/lat columns of *coord_csv* are detected and renamed. A
    ``year`` column, if present, is kept so the join can be year-specific.
    When the CSV has no lon/lat columns, coordinates are taken from
    ``{2015,2020,2025}MPA.xlsx`` found next to the panel, its parent, the
    working directory or the project root.

    Raises:
        FileNotFoundError: *coord_csv* does not exist.
        KeyError: no ID column in the CSV, or no coordinates from any source.
    """
    if not coord_csv.exists():
        raise FileNotFoundError(coord_csv)
    table = pd.read_csv(coord_csv)
    id_name = _detect_id_col(table)
    if id_name is None:
        raise KeyError("坐标表未检测到 MPA ID 列（如 MPA_ID / id）")
    lon_name, lat_name = _detect_coord_cols(table)

    if lon_name is not None and lat_name is not None:
        table = table.rename(columns={id_name: "MPA_ID_COORD", lon_name: COL_LON, lat_name: COL_LAT})
        wanted = ["MPA_ID_COORD", COL_LON, COL_LAT]
        if COL_YEAR in table.columns:
            wanted.append(COL_YEAR)
        return table[wanted].copy()

    print("[coords] 未发现 lon/lat，尝试从 2015/2020/2025MPA.xlsx 抽取坐标 ...")
    search_dirs = [
        panel_dir,
        panel_dir.parent,
        Path.cwd(),
        Path(r"D:\Fragmentation of MPA governance\mpa_network_misfit\pythonProject1\.venv"),
    ]
    fallback = _try_load_coords_from_mpa_excels(search_dirs)
    if fallback is None:
        raise KeyError("无法从坐标表或 MPA.xlsx 中获得经纬度列。")
    return fallback

# ---------- 面板读取 + 别名处理 ----------
def load_panel_and_harmonize(p: Path) -> pd.DataFrame:
    """Load the node panel CSV and harmonize its columns to the canonical names.

    Processing order (each step sees the output of the previous one):

    1. Rename the year / sea-region columns (found via ``_first_match``) to
       ``COL_YEAR`` / ``COL_SEA``.
    2. Keep only rows whose year equals ``FILTER_YEAR``. Years are compared
       numerically, so ``2025``, ``2025.0`` and ``"2025"`` all match.
    3. If ``FILTER_TO_MISFIT``, keep rows whose misfit column is 1/True. When no
       misfit column exists, derive it as ``Y_t >= 1``.
    4. Spatial lag (``COL_SLAG``). Unless ``FORCE_SLAG_FROM_Y`` is set, use an
       existing z-scored column (renamed to ``COL_SLAG``) or z-score a raw one.
       If forced, or still missing, use the z-score of the Y_t column.
    5. Corridor (``COL_CORR``): an existing z column, a z-scored raw column, the
       first available entry of ``CORR_SURROGATES``, or the constant 0.
    6. LOOI (``COL_LOOI``, optional): an existing z column or a z-scored raw column.

    All z-scores are computed on the rows that survive steps 2–3.

    Args:
        p: Path to the panel CSV.

    Returns:
        The harmonized frame with a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: *p* does not exist.
        KeyError: no Y_t-like column for the SLAG fallback; no sea-region
            column; or, with ``STRICT_FAIL_SLAG``, no spatial-lag column.
    """
    if not p.exists():
        raise FileNotFoundError(p)
    df = pd.read_csv(p)

    if DEBUG_PRINT_COLUMNS:
        print("[harm] 面板列名预览：")
        print(list(df.columns))

    # --- 1. canonical year / sea-region names ---
    ycol = _first_match(df.columns, YEAR_ALIASES, ["^year$"])
    if ycol and ycol != COL_YEAR:
        df = df.rename(columns={ycol: COL_YEAR})
    scol = _first_match(df.columns, SEA_ALIASES, ["sea.*region"])
    if scol and scol != COL_SEA:
        df = df.rename(columns={scol: COL_SEA})

    # --- 2. year filter ---
    if COL_YEAR in df.columns:
        # Compare numerically: a year column read as text ("2025") would never
        # equal the integer FILTER_YEAR and would silently empty the panel.
        year_num = pd.to_numeric(df[COL_YEAR], errors="coerce")
        df = df[year_num == FILTER_YEAR].copy()
    else:
        print("[warn] 面板缺少年份列，已跳过按年份过滤。")

    # --- 3. misfit filter ---
    misf = _first_match(df.columns, MISFIT_ALIASES)
    if FILTER_TO_MISFIT:
        if misf:
            if misf != COL_MISF:
                df = df.rename(columns={misf: COL_MISF})
            df = df[(df[COL_MISF] == 1) | (df[COL_MISF] == True)].copy()  # noqa: E712
        else:
            # No explicit flag: treat Y_t >= 1 as "misfit".
            yflag = _first_match(df.columns, Y_ALIASES)
            if yflag:
                if yflag != Y_COL:
                    df = df.rename(columns={yflag: Y_COL})
                df[COL_MISF] = (pd.to_numeric(df[Y_COL], errors="coerce") >= 1).astype(int)
                df = df[df[COL_MISF] == 1].copy()
                print(f"[info] 未找到 {COL_MISF}，改用 {Y_COL} 作为错配筛选。")
            else:
                print("[warn] 未找到错配标记列（high_misfit_t1 / Y_t），已跳过错配筛选。")

    # --- 4. spatial lag (SLAG) ---
    if not FORCE_SLAG_FROM_Y:
        slag_col = _first_match(df.columns, SLAG_ALIASES,
                                ["spaw.*lag.*misfit.*t1.*z", "spw.*lag.*misfit.*t1.*z",
                                 "lag.*mis.*t1.*z", "lag.*Y.*t.*z"])
        if slag_col is None:
            slag_raw = _first_match(df.columns, SLAG_RAW_CAND,
                                    ["spaw.*lag.*misfit.*t1", "spw.*lag.*misfit.*t1",
                                     "lag.*mis.*t1", "lag.*Y.*t"])
            if slag_raw:
                df[COL_SLAG] = _zscore(df[slag_raw])
                print(f"[harm] 未找到 {COL_SLAG}，已从原始列 {slag_raw} Z标准化得到。")
            elif STRICT_FAIL_SLAG:
                raise KeyError(f"未找到空间滞后错配列（如 {COL_SLAG} 或常见原始列）。")
        elif slag_col != COL_SLAG:
            # Expose the matched alias under the canonical name; otherwise the
            # check below sees COL_SLAG as missing and overwrites it with Y_t.
            df = df.rename(columns={slag_col: COL_SLAG})
    if FORCE_SLAG_FROM_Y or (COL_SLAG not in df.columns):
        ycand = _first_match(df.columns, Y_ALIASES, ["^Y_t$", r"\bY\b", "label", "target"])
        if ycand is None:
            raise KeyError("无法用 Y_t 兜底：未在面板中找到 Y_t（或其别名）列。")
        df[COL_SLAG] = _zscore(df[ycand])
        print(f"[surrogate] 已用 {ycand} 的 Z 值作为临时 {COL_SLAG}。")

    # --- 5. corridor betweenness (CORR) ---
    corr = _first_match(df.columns, CORR_ALIASES,
                        ["corridor.*betw.*t1.*z", "corridor.*btw.*t1.*z", "corr.*btw.*t1.*z"])
    if corr is None:
        corr_raw = _first_match(df.columns, CORR_RAW_CAND,
                                ["corridor.*betw.*t1", "corridor.*btw.*t1", "corr.*btw.*t1"])
        if corr_raw:
            df[COL_CORR] = _zscore(df[corr_raw])
            print(f"[harm] 未找到 {COL_CORR}，已从原始列 {corr_raw} Z标准化得到。")
        else:
            cands = [c for c in CORR_SURROGATES if c in df.columns]
            if cands:
                df[COL_CORR] = _zscore(df[cands[0]])
                print(f"[surrogate] 未找到 corridor 指标，用 {cands[0]} 的 Z 值近似作为 {COL_CORR}。")
            else:
                df[COL_CORR] = 0.0
                print("[warn] 未找到任何走廊性指标（corridor/SC_gate/SC_deg），已用常数 0 兜底。")
    elif corr != COL_CORR:
        df = df.rename(columns={corr: COL_CORR})

    # --- 6. LOOI (optional) ---
    looi = _first_match(df.columns,
                        ["z_prov_count_looi_t1_z", "prov_count_looi_t1_z", "looi_t1_z", "looi_z"],
                        ["looi.*t1.*z"])
    if looi is None:
        looi_raw = _first_match(df.columns,
                                ["prov_count_looi_t1", "looi_t1", "looi", "prov_count_looi"],
                                ["looi", "prov.*looi.*t1"])
        if looi_raw:
            df[COL_LOOI] = _zscore(df[looi_raw])
            print(f"[harm] 未找到 {COL_LOOI}，已从原始列 {looi_raw} Z标准化得到。")
        else:
            print(f"[info] 未找到 {COL_LOOI}（或其原始列），S4 仅调整 SLAG/CORR。")
    elif looi != COL_LOOI:
        df = df.rename(columns={looi: COL_LOOI})

    if COL_SEA not in df.columns:
        raise KeyError("未找到海区列（sea_region / SeaRegion / sea_region_std 等）。")

    if DEBUG_PRINT_COLUMNS:
        picked = {
            "SEA": COL_SEA,
            "YEAR": COL_YEAR if COL_YEAR in df.columns else "(missing)",
            "MISFIT": COL_MISF if COL_MISF in df.columns else "(derived/skip)",
            "SLAG": COL_SLAG,
            "CORRIDOR": COL_CORR,
            "LOOI(optional)": COL_LOOI if COL_LOOI in df.columns else "(none)",
        }
        print("[harm] 关键列映射：", picked)

    return df.reset_index(drop=True)

# ---------- 坐标合并 ----------
def join_coords(df: pd.DataFrame, coord_df: pd.DataFrame) -> Tuple[pd.DataFrame, str]:
    """Attach ``COL_LON`` / ``COL_LAT`` from *coord_df* to the panel *df*.

    The join key is the panel's detected ID column, plus ``COL_YEAR`` when both
    frames carry it. Rows that still lack a coordinate are back-filled from the
    first coordinate record of the same ID, regardless of year. A warning with
    the number of rows still missing lon or lat is printed.

    NOTE(review): if *df* already has ``lon``/``lat`` columns, the first merge
    suffixes them (``lon_x``/``lon_y``). The canonical columns are then rebuilt
    from the ID-only back-fill.

    Returns:
        ``(merged frame, panel ID column name)``. The ID column keeps its
        original name; an ``MPA_ID_COORD`` column is added.

    Raises:
        KeyError: the panel has no detectable ID column.
    """
    id_col = _detect_id_col(df)
    if id_col is None:
        raise KeyError("面板表未检测到 MPA ID 列。")
    d = df.rename(columns={id_col: "MPA_ID_JOIN"}).copy()
    cd = coord_df.copy()
    cd_first = cd.drop_duplicates("MPA_ID_COORD", keep="first")

    if COL_YEAR in d.columns and COL_YEAR in cd.columns:
        m = d.merge(cd, left_on=["MPA_ID_JOIN", COL_YEAR], right_on=["MPA_ID_COORD", COL_YEAR], how="left")
    else:
        m = d.merge(cd_first, left_on="MPA_ID_JOIN", right_on="MPA_ID_COORD", how="left")

    for coord_col in (COL_LON, COL_LAT):
        if coord_col not in m.columns:
            m[coord_col] = np.nan

    miss = m[COL_LON].isna() | m[COL_LAT].isna()
    if miss.any():
        # Year-agnostic back-fill; a left merge on the de-duplicated table keeps
        # row count and order, so .values lines up with the masked rows.
        back = m.loc[miss, ["MPA_ID_JOIN"]].merge(
            cd_first, left_on="MPA_ID_JOIN", right_on="MPA_ID_COORD", how="left"
        )[[COL_LON, COL_LAT]]
        m.loc[miss, COL_LON] = back[COL_LON].values
        m.loc[miss, COL_LAT] = back[COL_LAT].values

    m = m.rename(columns={"MPA_ID_JOIN": id_col})
    # Count ROWS missing either coordinate; summing both NaN counts and halving
    # miscounts rows that lack only one of them.
    n_miss = int((m[COL_LON].isna() | m[COL_LAT].isna()).sum())
    if n_miss > 0:
        print(f"[coords] 警告：仍有 {n_miss} 条记录缺少坐标（图上会忽略它们）。")
    return m, id_col

# ---------- 线性预测器（含 FE） ----------
_FE_SEA_RE  = re.compile(r"^C\(sea_region\)\[T\.(.+)\]$")
_FE_YEAR_RE = re.compile(r"^C\(year\)\[T\.(.+)\]$")


def _fe_level_variants(value) -> set:
    """Return the string spellings under which *value* may appear as a factor level.

    Always contains ``str(value)``. For an integer-valued finite number, both the
    integer and the float spelling are included (``"2025"`` and ``"2025.0"``).
    A row whose year was upcast to float therefore still matches a term like
    ``C(year)[T.2025]``, and vice versa.
    """
    variants = {str(value)}
    if isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(value, bool):
        num = float(value)
        if np.isfinite(num) and num.is_integer():
            variants.add(str(int(num)))
            variants.add(str(num))
    return variants


def linear_predict(row: pd.Series, coefs: Dict[str,float]) -> float:
    """Evaluate the fitted logit's linear predictor for one panel row.

    The result is the sum of:
      * ``coefs["Intercept"]`` (0 if absent);
      * ``b * row[term]`` for every term that is not ``Intercept``, does not
        start with ``C(``, and names a column of *row* (NaN values count as 0);
      * the coefficient of each ``C(sea_region)[T.level]`` /
        ``C(year)[T.level]`` term whose level equals the row's sea region /
        year. Integer-valued years match whether stored as int or float.

    Terms naming columns that *row* lacks (e.g. interactions) are ignored.
    """
    s = coefs.get("Intercept", 0.0)
    for col, b in coefs.items():
        if col.startswith("C(") or col == "Intercept":
            continue
        if col in row.index:
            val = row[col]
            if pd.isna(val):
                val = 0.0
            s += b * float(val)

    sea_levels = _fe_level_variants(row.get(COL_SEA, ""))
    year_levels = _fe_level_variants(row.get(COL_YEAR, ""))
    for k, b in coefs.items():
        m = _FE_SEA_RE.match(k)
        if m and m.group(1) in sea_levels:
            s += b
        m = _FE_YEAR_RE.match(k)
        if m and m.group(1) in year_levels:
            s += b
    return float(s)

def apply_clip(x):
    """Clamp scalar *x* into [CLIP_MIN, CLIP_MAX] and return it as a Python float."""
    bounded = np.clip(x, CLIP_MIN, CLIP_MAX)
    return float(bounded)

# ---------- 自动标定Δ ----------
def _median_slope(p):  # logit 的边际斜率 ~ p*(1-p)
    return float(np.median(p*(1-p)))

def calibrate_deltas(df_plot, coefs, p0,
                     q_s1=0.50, q_s2=0.25,
                     target_dp=None,
                     S3_DAMP=0.85):
    """Derive scenario intensities that approximately hit target probability improvements.

    The logit is linearised at the median slope ``dp/dlogit = median(p0*(1-p0))``,
    so a target probability change ``dp`` maps to a logit change ``dp / slope``.

    Returned keys (each only when its coefficient and reference magnitude are usable):
      * ``K1``: multiplier on a node's SLAG excess above its sea-region ``q_s1``
        quantile. A node whose excess equals the 75th percentile of positive
        excesses shifts the logit by ``target_dp["S1"] / slope`` in magnitude.
      * ``K2``: the same for the CORR gap below the sea-region ``q_s2`` quantile.
      * ``K1_S3`` / ``K2_S3``: K1/K2 rescaled to ``target_dp["S3"]`` and
        multiplied by ``S3_DAMP``.
      * ``DELTA_S4_SLAG`` / ``DELTA_S4_CORR``: uniform shifts whose combined
        logit effect is ``-target_dp["S4"] / slope``. It is split between SLAG
        and CORR in proportion to ``|b_slag|`` and ``|b_corr|``, and each shift
        moves its regressor in the direction that lowers the logit.

    Args:
        df_plot: node frame containing ``COL_SEA``, ``COL_SLAG``, ``COL_CORR``.
        coefs: term -> coefficient mapping.
        p0: baseline probabilities aligned with *df_plot*.
        q_s1, q_s2: sea-region quantiles defining the SLAG excess / CORR gap.
        target_dp: per-scenario targets. Defaults to
            ``{"S1": 0.05, "S2": 0.05, "S3": 0.07, "S4": 0.14}``.
        S3_DAMP: extra damping factor applied to the S3 intensities.

    Returns:
        Dict of calibrated parameters. It is empty when the median slope is not
        a positive finite number; the scenario functions then use their fixed
        DELTA_* fallbacks.
    """
    if target_dp is None:
        target_dp = {"S1": 0.05, "S2": 0.05, "S3": 0.07, "S4": 0.14}
    b_slag = coefs.get(COL_SLAG, None)
    b_corr = coefs.get(COL_CORR, None)
    slope = _median_slope(p0)

    out = {}

    # All p0 at 0/1 (slope 0) or no rows (NaN) would make every target divide
    # by zero / propagate NaN; leave the calibration empty instead.
    if not np.isfinite(slope) or slope <= 0:
        print(f"[calib] 基线边际斜率无效（slope={slope}），跳过自动标定，情景使用固定 Δ。")
        return out

    # Reference magnitudes: 75th percentile of the positive per-node SLAG excess
    # above, and CORR gap below, the sea-region quantiles.
    g1 = df_plot.groupby(COL_SEA)[COL_SLAG].transform(lambda s: s.quantile(q_s1))
    over = (df_plot[COL_SLAG] - g1).clip(lower=0)
    qref1 = float(np.quantile(over[over > 0], 0.75)) if (over > 0).any() else 0.0

    g2 = df_plot.groupby(COL_SEA)[COL_CORR].transform(lambda s: s.quantile(q_s2))
    gap = (g2 - df_plot[COL_CORR]).clip(lower=0)
    qref2 = float(np.quantile(gap[gap > 0], 0.75)) if (gap > 0).any() else 0.0

    if b_slag and qref1 > 0:
        dlogit_s1 = target_dp["S1"] / slope
        out["K1"] = dlogit_s1 / (abs(b_slag) * qref1)
    if b_corr and qref2 > 0:
        dlogit_s2 = target_dp["S2"] / slope
        out["K2"] = dlogit_s2 / (abs(b_corr) * qref2)

    if "K1" in out:
        out["K1_S3"] = S3_DAMP * out["K1"] * (target_dp["S3"] / target_dp["S1"])
    if "K2" in out:
        out["K2_S3"] = S3_DAMP * out["K2"] * (target_dp["S3"] / target_dp["S2"])

    if b_slag or b_corr:
        w1 = abs(b_slag) if b_slag else 0.0
        w2 = abs(b_corr) if b_corr else 0.0
        if (w1 + w2) > 0:
            dlogit_s4 = target_dp["S4"] / slope
            dlogit_s4_slag = dlogit_s4 * (w1 / (w1 + w2))
            dlogit_s4_corr = dlogit_s4 * (w2 / (w1 + w2))
            # Both shifts must LOWER the logit (improvement = p0 - p): moving a
            # regressor by -sign(b) * x / |b| changes the logit by exactly -x.
            # Using +sign(b_corr) for CORR would raise the logit and cancel the
            # SLAG half.
            out["DELTA_S4_SLAG"] = -np.sign(b_slag) * dlogit_s4_slag / abs(b_slag) if b_slag else 0.0
            out["DELTA_S4_CORR"] = -np.sign(b_corr) * dlogit_s4_corr / abs(b_corr) if b_corr else 0.0

    print("[calib] targets:", target_dp, "| slope≈", round(slope, 4))
    print("[calib] b_slag=", b_slag, "b_corr=", b_corr)
    print("[calib] K1,K2:", {k: round(v, 3) for k, v in out.items() if k in ("K1", "K2")})
    print("[calib] K1_S3,K2_S3:", {k: round(v, 3) for k, v in out.items() if "S3" in k})
    print("[calib] S4 deltas:", {k: round(v, 3) for k, v in out.items() if "S4" in k})
    return out

# ---------- 情景构造 ----------
def _quant_by_group(df, col, q):
    """Per-row *q*-quantile of *col* within the row's sea region (``COL_SEA``).

    Rows whose group quantile is NaN get the global *q*-quantile of *col*
    instead. This covers rows with a missing sea region, which groupby drops,
    and groups with no numeric values.
    """
    per_row = df.groupby(COL_SEA)[col].transform(lambda s: s.quantile(q))
    if not per_row.isna().any():
        return per_row
    return per_row.fillna(df[col].quantile(q))

def scenario_S1(df: pd.DataFrame, params: dict) -> pd.DataFrame:
    """S1 (sea-region governance): lower spatial lag where it exceeds the sea-region norm.

    With a calibrated ``K1``, every node's ``COL_SLAG`` is reduced by
    ``K1 * max(SLAG - q, 0)``, where ``q`` is the sea-region ``S1_TOP_Q``
    quantile. The whole column is then clipped to [CLIP_MIN, CLIP_MAX].
    Without ``K1``, only nodes with ``SLAG >= q`` are shifted by ``DELTA_S1``
    and clipped individually.

    NOTE(review): in the K1 branch the clip touches every row, so untargeted
    nodes with |SLAG| > CLIP_MAX also change. Confirm this is intended.

    Returns a modified copy; *df* is not mutated.
    """
    out = df.copy()
    threshold = _quant_by_group(out, COL_SLAG, S1_TOP_Q)
    k1 = params.get("K1", None)

    if k1 is not None:
        excess = (out[COL_SLAG] - threshold).clip(lower=0)
        out[COL_SLAG] = (out[COL_SLAG] - k1 * excess).clip(CLIP_MIN, CLIP_MAX)
        return out

    above = out[COL_SLAG] >= threshold
    out.loc[above, COL_SLAG] = out.loc[above, COL_SLAG].apply(lambda v: apply_clip(v + DELTA_S1))
    return out

def scenario_S2(df: pd.DataFrame, params: dict) -> pd.DataFrame:
    """S2 (corridor governance): raise corridor betweenness where it lags the sea-region norm.

    With a calibrated ``K2``, every node's ``COL_CORR`` is increased by
    ``K2 * max(q - CORR, 0)``, where ``q`` is the sea-region 25th percentile.
    The whole column is then clipped to [CLIP_MIN, CLIP_MAX]. Without ``K2``,
    only nodes with ``CORR <= q`` are shifted by ``DELTA_S2`` and clipped
    individually.

    NOTE(review): in the K2 branch the clip touches every row, so untargeted
    nodes with |CORR| > CLIP_MAX also change. Confirm this is intended.

    Returns a modified copy; *df* is not mutated.
    """
    out = df.copy()
    threshold = _quant_by_group(out, COL_CORR, 0.25)
    k2 = params.get("K2", None)

    if k2 is not None:
        shortfall = (threshold - out[COL_CORR]).clip(lower=0)
        out[COL_CORR] = (out[COL_CORR] + k2 * shortfall).clip(CLIP_MIN, CLIP_MAX)
        return out

    below = out[COL_CORR] <= threshold
    out.loc[below, COL_CORR] = out.loc[below, COL_CORR].apply(lambda v: apply_clip(v + DELTA_S2))
    return out

def scenario_S3(df: pd.DataFrame, params: dict) -> pd.DataFrame:
    """S3 (targeted governance): apply the S1 and S2 adjustments together.

    Uses ``K1_S3`` / ``K2_S3`` when present, otherwise ``K1`` / ``K2``. A missing
    intensity falls back to the fixed ``DELTA_S1`` / ``DELTA_S2`` shift on the
    targeted nodes. Thresholds, excess and gap are all measured on the
    unmodified input. Both columns are clipped to [CLIP_MIN, CLIP_MAX] over all
    rows at the end.

    Returns a modified copy; *df* is not mutated.
    """
    out = df.copy()
    k1 = params.get("K1_S3", params.get("K1", None))
    k2 = params.get("K2_S3", params.get("K2", None))

    slag_cut = _quant_by_group(out, COL_SLAG, S1_TOP_Q)
    corr_cut = _quant_by_group(out, COL_CORR, 0.25)
    slag_excess = (out[COL_SLAG] - slag_cut).clip(lower=0)
    corr_shortfall = (corr_cut - out[COL_CORR]).clip(lower=0)

    if k1 is not None:
        out[COL_SLAG] = out[COL_SLAG] - k1 * slag_excess
    else:
        hit = out[COL_SLAG] >= slag_cut
        out.loc[hit, COL_SLAG] = out.loc[hit, COL_SLAG].apply(lambda v: apply_clip(v + DELTA_S1))

    if k2 is not None:
        out[COL_CORR] = out[COL_CORR] + k2 * corr_shortfall
    else:
        hit = out[COL_CORR] <= corr_cut
        out.loc[hit, COL_CORR] = out.loc[hit, COL_CORR].apply(lambda v: apply_clip(v + DELTA_S2))

    # NOTE(review): this final clip applies to every row, including untargeted ones.
    out[COL_SLAG] = out[COL_SLAG].clip(CLIP_MIN, CLIP_MAX)
    out[COL_CORR] = out[COL_CORR].clip(CLIP_MIN, CLIP_MAX)
    return out

def scenario_S4(df: pd.DataFrame, params: dict) -> pd.DataFrame:
    """S4 (unified governance): shift SLAG, CORR and (if present) LOOI for every node.

    SLAG and CORR are moved by the calibrated ``DELTA_S4_SLAG`` /
    ``DELTA_S4_CORR``, falling back to ``DELTA_S1`` / ``DELTA_S2`` when the key
    is absent or None. Both are then clipped to [CLIP_MIN, CLIP_MAX]. When the
    frame has ``COL_LOOI``, it is shifted by ``DELTA_LOOI`` and clipped per value.

    Returns a modified copy; *df* is not mutated.
    """
    out = df.copy()
    shift_slag = params.get("DELTA_S4_SLAG", None)
    shift_corr = params.get("DELTA_S4_CORR", None)
    # Explicit None checks: a calibrated shift of 0.0 is a valid value.
    if shift_slag is None:
        shift_slag = DELTA_S1
    if shift_corr is None:
        shift_corr = DELTA_S2

    out[COL_SLAG] = (out[COL_SLAG] + shift_slag).clip(CLIP_MIN, CLIP_MAX)
    out[COL_CORR] = (out[COL_CORR] + shift_corr).clip(CLIP_MIN, CLIP_MAX)
    if COL_LOOI in out.columns:
        out[COL_LOOI] = out[COL_LOOI].apply(lambda v: apply_clip(v + DELTA_LOOI))
    return out

# ================== 主流程 ==================
for _label, _src in (("panel", PANEL_CSV), ("coefs", COEF_CSV), ("coords", COORD_CSV)):
    print(f"[scenario] {_label}: {_src}")

# Inputs: harmonized panel, model coefficients, node coordinates.
panel = load_panel_and_harmonize(PANEL_CSV)
coefs = parse_coefs(COEF_CSV)
coords = load_coords(COORD_CSV, PANEL_CSV.parent)

# Only nodes with both coordinates take part in the simulation and the maps.
df, id_col = join_coords(panel, coords)
df_plot = df.dropna(subset=[COL_LON, COL_LAT]).copy()

missing = [col for col in (COL_SEA, COL_SLAG, COL_CORR) if col not in df_plot.columns]
if missing:
    raise KeyError(f"关键列缺失：{missing}。")

# Baseline predicted misfit probability per node.
lp0 = df_plot.apply(lambda r: linear_predict(r, coefs), axis=1)
p0 = invlogit(lp0.values)

# Auto-calibration of scenario intensities (S3 damped, S4 boosted).
params = calibrate_deltas(
    df_plot,
    coefs,
    p0,
    q_s1=S1_TOP_Q,
    q_s2=0.25,
    target_dp={"S1": 0.05, "S2": 0.05, "S3": 0.07, "S4": 0.16},
    S3_DAMP=0.75,
)


def _predict_proba(frame: pd.DataFrame) -> np.ndarray:
    """Row-wise linear predictor of *frame* under ``coefs``, mapped to probabilities."""
    return invlogit(frame.apply(lambda r: linear_predict(r, coefs), axis=1).values)


# Four counterfactuals, each derived from (a copy of) the baseline frame.
S1 = scenario_S1(df_plot, params)
S2 = scenario_S2(df_plot, params)
S3 = scenario_S3(df_plot, params)
S4 = scenario_S4(df_plot, params)

p1, p2, p3, p4 = (_predict_proba(frame) for frame in (S1, S2, S3, S4))

# Improvement = drop in predicted misfit probability relative to baseline.
imp1, imp2, imp3, imp4 = (p0 - p_scen for p_scen in (p1, p2, p3, p4))

# ================== 导出 ==================
IMP_COLS = ["imp_S1", "imp_S2", "imp_S3", "imp_S4"]

# Per-node table: ID, sea region, location, baseline probability and the
# per-scenario improvements.
node_out = df_plot[[id_col, COL_SEA, COL_LON, COL_LAT]].copy()
node_out["p_baseline"] = p0
for _col, _vals in zip(IMP_COLS, (imp1, imp2, imp3, imp4)):
    node_out[_col] = _vals
node_out.to_csv(OUT_DIR / "scenario_node_improvements.csv", index=False, encoding="utf-8-sig")

# Mean improvement per sea region and overall.
sum_by_sea = node_out.groupby(COL_SEA)[IMP_COLS].mean().reset_index()
sum_by_sea.to_csv(OUT_DIR / "scenario_summary_by_sea.csv", index=False, encoding="utf-8-sig")
sum_overall = node_out[IMP_COLS].mean().to_frame("avg_improvement")
sum_overall.to_csv(OUT_DIR / "scenario_summary_overall.csv", encoding="utf-8-sig")

# Nodes gaining most under the unified scenario S4.
TOPN = 30
top4 = node_out.sort_values("imp_S4", ascending=False).head(TOPN)
top4.to_csv(OUT_DIR / "top_improved_nodes.csv", index=False, encoding="utf-8-sig")

# targeting 诊断
def _share_and_p90pos(arr):
    arr = np.asarray(arr)
    share = float((arr>0).mean())
    p90   = float(np.quantile(arr[arr>0], 0.90)) if (arr>0).any() else 0.0
    return share, p90
# One summary row per scenario: mean improvement, share of improved nodes and
# the p90 of the positive improvements.
rows = []
for scen_name, imp_vec in zip(("S1", "S2", "S3", "S4"), (imp1, imp2, imp3, imp4)):
    share_pos, p90_pos = _share_and_p90pos(imp_vec)
    rows.append({
        "scenario": scen_name,
        "mean": float(np.mean(imp_vec)),
        "share_positive": share_pos,
        "p90_positive": p90_pos,
    })
pd.DataFrame(rows).to_csv(OUT_DIR / "scenario_summary_targeting.csv", index=False, encoding="utf-8-sig")

# ================== 可视化（自适应色标与气泡；Baseline 左列跨两行） ==================
def _vmax_auto(*imps, q=0.95, floor=0.01):
    arr = np.concatenate([np.abs(x) for x in imps], axis=0)
    vmax = float(np.quantile(arr, q)) if arr.size else floor
    return max(vmax, floor)

def _kbubble_auto(imp, base=30.0, target_area=450.0):
    pos = imp[imp>0]
    p90 = float(np.quantile(pos, 0.90)) if pos.size else 0.0
    if p90 <= 1e-6:
        return 2000.0
    k = (target_area - base) / p90
    return float(np.clip(k, 800.0, 10000.0))

# Shared symmetric colour limits and bubble scaling (bubble size keyed to S4).
VMAX = _vmax_auto(imp1, imp2, imp3, imp4, q=0.95, floor=0.01)
VMIN = -VMAX
K_BUBBLE_DRAW = _kbubble_auto(imp4, base=BASE_SIZE, target_area=450.0)

# 2 x 3 grid: the baseline fills the whole left column, and the four
# scenarios occupy the 2 x 2 block on the right.
fig = plt.figure(figsize=(14, 9))
gs = fig.add_gridspec(2, 3, width_ratios=[1, 1, 1], height_ratios=[1, 1], wspace=0.25, hspace=0.22)

ax0 = fig.add_subplot(gs[:, 0])
ax1, ax2 = fig.add_subplot(gs[0, 1]), fig.add_subplot(gs[0, 2])
ax3, ax4 = fig.add_subplot(gs[1, 1]), fig.add_subplot(gs[1, 2])

# Baseline panel: faint grey underlay plus the styled node layer.
_lon, _lat = df_plot[COL_LON], df_plot[COL_LAT]
ax0.scatter(_lon, _lat, s=10, color="#d0d0d0", edgecolor="none", alpha=0.35)
ax0.scatter(
    _lon, _lat,
    s=BASE_SIZE, color=BASELINE_NODE_COLOR, edgecolor=BASELINE_EDGE_COLOR, lw=0.5, alpha=0.85,
    label=f"{FILTER_YEAR} misfit (baseline)",
)
ax0.set_title(f"Baseline — {FILTER_YEAR} Misfit Nodes")
ax0.set_xlabel("lon")
ax0.set_ylabel("lat")
ax0.legend(loc="lower left", frameon=True)

def draw(ax, title, improvement):
    """Plot one scenario panel on *ax* and return its main scatter handle.

    Colour encodes *improvement* on a shared coolwarm scale [VMIN, VMAX].
    Bubble area is ``BASE_SIZE + K_BUBBLE_DRAW * max(improvement, 0)``, so
    only improving nodes grow. A faint dot layer marks every node, and a
    corner box reports the share of improved nodes and the mean improvement.
    The returned handle is reused for the shared colour bar.
    """
    lon, lat = df_plot[COL_LON], df_plot[COL_LAT]
    bubble_area = BASE_SIZE + K_BUBBLE_DRAW * np.maximum(0, improvement)
    sc = ax.scatter(
        lon, lat,
        s=bubble_area,
        c=improvement, cmap="coolwarm", vmin=VMIN, vmax=VMAX,
        edgecolor="white", lw=0.4, alpha=0.9,
    )
    ax.scatter(lon, lat, s=8, color="#cfcfd4", alpha=0.25, edgecolor="none")

    pos_share = float((improvement > 0).mean())
    mean_imp = float(np.mean(improvement))
    box_style = dict(boxstyle="round,pad=0.2", fc="white", ec="#ddd")
    ax.text(
        0.02, 0.98, f"+%={pos_share:.0%}\nμ={mean_imp:.3f}",
        ha="left", va="top", transform=ax.transAxes,
        fontsize=9, color="#444", bbox=box_style,
    )
    ax.set_title(title)
    ax.set_xlabel("lon")
    ax.set_ylabel("lat")
    return sc

# Scenario panels, drawn in S1..S4 order.
_panel_specs = [
    (ax1, "S1 — Sea-region Governance", imp1),
    (ax2, "S2 — Corridor Governance", imp2),
    (ax3, "S3 — Targeted Governance (damped)", imp3),
    (ax4, "S4 — Unified Governance", imp4),
]
sc1, sc2, sc3, sc4 = (draw(*spec) for spec in _panel_specs)

# One colour bar for all scenario panels (they share VMIN/VMAX).
cax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
cb = fig.colorbar(sc4, cax=cax)
cb.set_label("Improvement (baseline − scenario)")

fig.suptitle("Scenario Simulation of Spatial–Institutional Misfit", fontsize=18, weight="bold")
out_png = FIG_DIR / "map_composite.png"
fig.savefig(out_png, dpi=150, bbox_inches="tight")
plt.close(fig)

print("[scenario] saved map:", out_png)
print("[scenario] done.")
